library(MASS)
library(haven)
library(Matrix)
library(lfe)
library(ggplot2)
library(RColorBrewer)
library(Hotelling)
library(did)

source("functions.R")

### run data_cleaning.R and classification.R first
load("data_classified.RData")

# N and T are not defined in this script; they are expected to come from
# data_classified.RData (presumably the number of units and the number of time
# periods -- TODO confirm against classification.R). Without this check a
# missing T would silently resolve to R's built-in alias T == TRUE, and the
# array() calls below would build a single slice instead of failing.
stopifnot(
  is.numeric(N), length(N) == 1,
  is.numeric(T), length(T) == 1
)

# construct placeholder matrices for propensity score ratios
# Pi is N x 2 with ones in column 1 and zeros in column 2; L is N x 1 of zeros.
Pi <- matrix(0, N, 2)
Pi[, 1] <- 1
L <- matrix(0, N, 1)
# array() recycles its data to fill the requested dimensions, so each of the
# T slices along the third dimension is a copy of Pi (resp. L).
stackPi <- array(Pi, c(N, 2, T))
stackL <- array(L, c(N, 1, T))

# Each results table built below has pre + post rows per type, with columns:
# beta, type, r, sd, CI (lower and upper bound)
pre <- 4
post <- 6

### K=2
# Figure 1 of the manuscript
# One row per (type, r) pair; columns are beta, type, r, sd, lb, ub.
temp <- matrix(0, (pre + post) * 2, 6)
# type 1 and 2
temp[, 2] <- c(rep(1, pre + post), rep(2, pre + post))
# r runs over -(pre + 1), ..., -2 and 0, ..., post - 1 for each type
# (r = -1 is skipped).
temp[, 3] <- c((-pre - 1):-2, 0:(post - 1), (-pre - 1):-2, 0:(post - 1))
# The data matrix is the same for every row, so convert it once.
catt_data <- as.matrix(KmeansData[, 1:5])
for (i in seq_len(nrow(temp))) {
  # Call estCATT() once per row and take both the point estimate and the
  # score vector from that single result, instead of evaluating the same
  # call twice. NOTE(review): assumes estCATT() is deterministic -- confirm
  # in functions.R.
  catt_fit <- estCATT(kmeansK2$gamma,
                      stackPi,
                      stackL,
                      catt_data, temp[i, 3], temp[i, 2])
  temp[i, 1] <- catt_fit$estimate
  temp[i, 4] <- sd(catt_fit$score) / N^(0.5)
}
# bounds at +/- 1.96 standard errors
temp[, 5] <- temp[, 1] - 1.96 * temp[, 4]
temp[, 6] <- temp[, 1] + 1.96 * temp[, 4]

temp <- as.data.frame(temp)
colnames(temp) <- c("beta", "type", "r", "sd", "lb", "ub")
# factor so that ggplot maps type to discrete colours
temp$type <- as.factor(temp$type)

# graph with 95% CI
temp_palette <- brewer.pal(n = 5, name = 'Set1')
ggplot(temp, aes(x = r, y = beta, group = type)) +
  geom_line(aes(color = type), linewidth = 1.5) +
  geom_point(aes(color = type), size = 2) +
  geom_errorbar(aes(ymin = lb, ymax = ub, color = type),
                linewidth = 0.5, width = .2) +
  scale_colour_manual(values = temp_palette[1:5]) +
  geom_hline(yintercept = 0, color = "black") +
  theme_light()
temp

### K=3
# Figure 1 of SA
# One row per (type, r) pair; columns are beta, type, r, sd, lb, ub.
temp <- matrix(0, (pre + post) * 3, 6)
# type 1, 2 and 3
temp[, 2] <- c(rep(1, pre + post), rep(2, pre + post), rep(3, pre + post))
# kronecker() with a vector of ones stacks the r grid once per type;
# r = -1 is skipped.
temp[, 3] <- kronecker(c(1, 1, 1), c((-pre - 1):-2, 0:(post - 1)))
# The data matrix is the same for every row, so convert it once.
catt_data <- as.matrix(KmeansData[, 1:5])
for (i in seq_len(nrow(temp))) {
  # Call estCATT() once per row and take both the point estimate and the
  # score vector from that single result, instead of evaluating the same
  # call twice. NOTE(review): assumes estCATT() is deterministic -- confirm
  # in functions.R.
  catt_fit <- estCATT(kmeansK3$gamma,
                      stackPi,
                      stackL,
                      catt_data, temp[i, 3], temp[i, 2])
  temp[i, 1] <- catt_fit$estimate
  temp[i, 4] <- sd(catt_fit$score) / N^(0.5)
}
# bounds at +/- 1.96 standard errors
temp[, 5] <- temp[, 1] - 1.96 * temp[, 4]
temp[, 6] <- temp[, 1] + 1.96 * temp[, 4]

temp <- as.data.frame(temp)
colnames(temp) <- c("beta", "type", "r", "sd", "lb", "ub")
# factor so that ggplot maps type to discrete colours
temp$type <- as.factor(temp$type)

# graph with 95% CI
temp_palette <- brewer.pal(n = 5, name = 'Set1')
ggplot(temp, aes(x = r, y = beta, group = type)) +
  geom_line(aes(color = type), linewidth = 1.5) +
  geom_point(aes(color = type), size = 2) +
  geom_errorbar(aes(ymin = lb, ymax = ub, color = type),
                linewidth = 0.5, width = .2) +
  scale_colour_manual(values = temp_palette[1:5]) +
  geom_hline(yintercept = 0, color = "black") +
  theme_light()
temp

### K=4
# Figure 2 of SA
# One row per (type, r) pair; columns are beta, type, r, sd, lb, ub.
temp <- matrix(0, (pre + post) * 4, 6)
# type 1, 2, 3 and 4
temp[, 2] <- c(rep(1, pre + post), rep(2, pre + post),
               rep(3, pre + post), rep(4, pre + post))
# kronecker() with a vector of ones stacks the r grid once per type;
# r = -1 is skipped.
temp[, 3] <- kronecker(c(1, 1, 1, 1), c((-pre - 1):-2, 0:(post - 1)))
# The data matrix is the same for every row, so convert it once.
catt_data <- as.matrix(KmeansData[, 1:5])
for (i in seq_len(nrow(temp))) {
  # Call estCATT() once per row and take both the point estimate and the
  # score vector from that single result, instead of evaluating the same
  # call twice. NOTE(review): assumes estCATT() is deterministic -- confirm
  # in functions.R.
  catt_fit <- estCATT(kmeansK4$gamma,
                      stackPi,
                      stackL,
                      catt_data, temp[i, 3], temp[i, 2])
  temp[i, 1] <- catt_fit$estimate
  temp[i, 4] <- sd(catt_fit$score) / N^(0.5)
}
# bounds at +/- 1.96 standard errors
temp[, 5] <- temp[, 1] - 1.96 * temp[, 4]
temp[, 6] <- temp[, 1] + 1.96 * temp[, 4]

temp <- as.data.frame(temp)
colnames(temp) <- c("beta", "type", "r", "sd", "lb", "ub")
# factor so that ggplot maps type to discrete colours
temp$type <- as.factor(temp$type)

# graph with 95% CI
temp_palette <- brewer.pal(n = 5, name = 'Set1')
ggplot(temp, aes(x = r, y = beta, group = type)) +
  geom_line(aes(color = type), linewidth = 1.5) +
  geom_point(aes(color = type), size = 2) +
  geom_errorbar(aes(ymin = lb, ymax = ub, color = type),
                linewidth = 0.5, width = .2) +
  scale_colour_manual(values = temp_palette[1:5]) +
  geom_hline(yintercept = 0, color = "black") +
  theme_light()
temp


### using never-treated units only
# Figure 3 of SA
# One row per (type, r) pair; columns are beta, type, r, sd, lb, ub.
temp <- matrix(0, (pre + post) * 2, 6)
# type 1 and 2
temp[, 2] <- c(rep(1, pre + post), rep(2, pre + post))
# r runs over -(pre + 1), ..., -2 and 0, ..., post - 1 for each type
# (r = -1 is skipped).
temp[, 3] <- c((-pre - 1):-2, 0:(post - 1), (-pre - 1):-2, 0:(post - 1))
# The data matrix is the same for every row, so convert it once.
catt_data <- as.matrix(KmeansData[, 1:5])
for (i in seq_len(nrow(temp))) {
  # Call estCATT() once per row and take both the point estimate and the
  # score vector from that single result, instead of evaluating the same
  # call twice. NOTE(review): assumes estCATT() is deterministic -- confirm
  # in functions.R.
  catt_fit <- estCATT(kmeansNT$gamma,
                      stackPi,
                      stackL,
                      catt_data, temp[i, 3], temp[i, 2])
  temp[i, 1] <- catt_fit$estimate
  temp[i, 4] <- sd(catt_fit$score) / N^(0.5)
}
# bounds at +/- 1.96 standard errors
temp[, 5] <- temp[, 1] - 1.96 * temp[, 4]
temp[, 6] <- temp[, 1] + 1.96 * temp[, 4]

temp <- as.data.frame(temp)
colnames(temp) <- c("beta", "type", "r", "sd", "lb", "ub")
# factor so that ggplot maps type to discrete colours
temp$type <- as.factor(temp$type)

# graph with 95% CI
temp_palette <- brewer.pal(n = 5, name = 'Set1')
ggplot(temp, aes(x = r, y = beta, group = type)) +
  geom_line(aes(color = type), linewidth = 1.5) +
  geom_point(aes(color = type), size = 2) +
  geom_errorbar(aes(ymin = lb, ymax = ub, color = type),
                linewidth = 0.5, width = .2) +
  scale_colour_manual(values = temp_palette[1:5]) +
  geom_hline(yintercept = 0, color = "black") +
  theme_light()
temp

### Callaway and Sant'Anna
# Figure 1 of the manuscript
CSAdata <- KmeansData

# Attach each unit's time == 0 covariate values to all of its rows.
# For every row, init_row is the index of the (first) time == 0 row that has
# the same id. This replaces a per-row scan of the whole panel with one
# match() lookup. An id without a time == 0 row gets NA, so the new columns
# are always plain vectors rather than lists.
base_rows <- which(CSAdata$time == 0)
init_row <- base_rows[match(CSAdata$id, CSAdata$id[base_rows])]
CSAdata$p_white_init <- CSAdata$p_white[init_row]
# p_hisp_init and n_student_init are not used in the covariate formula below;
# they are kept so CSAdata has the same columns as before.
CSAdata$p_hisp_init <- CSAdata$p_hisp[init_row]
CSAdata$n_student_init <- CSAdata$n_student[init_row]

# Group-time ATTs, using not-yet-treated units as the comparison group and
# conditioning on ccbase and the time == 0 value of p_white.
CSAcatt <- att_gt(yname = "disd",
                  tname = "time",
                  idname = "id",
                  gname = "E",
                  xformla = ~ ccbase + p_white_init,
                  weightsname = "weight",
                  control_group = "notyettreated",
                  data = CSAdata)
# Aggregate the group-time effects into a dynamic (event-study) summary.
CSAatt <- aggte(CSAcatt, type = "dynamic", na.rm = TRUE)
CSAatt